{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ScitaPqhKtuW"
      },
      "source": [
        "##### Copyright 2022 The TensorFlow GNN Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hMqWDc_m6rUC"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JSJANANlu8hp"
      },
      "source": [
        "# Molecular Graph Classification with TF-GNN"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yzRjo2fLu9A1"
      },
      "source": [
        "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
        "  \u003ctd\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/gnn/blob/master/examples/tutorials/log_2022/neurips_student_tfgnn_graph_classification_mutag.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "  \u003ctd\u003e\n",
        "    \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/gnn/blob/main/examples/tutorials/log_2022/neurips_student_tfgnn_graph_classification_mutag.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n",
        "  \u003c/td\u003e\n",
        "\u003c/table\u003e"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IiVSmMbrtCl0"
      },
      "source": [
        "We will demonstrate how to do graph classification with TF-GNN.\n",
        "\n",
        "For this example, we will do molecular property prediction, where each molecule is represented as a graph. Nodes correspond to atoms, and edges represent the bonds between them. This is one of the application areas where GNNs are now the method of choice.\n",
        "\n",
        "We will use the MUTAG dataset, a common dataset from the [TUDatasets](https://chrsmrrs.github.io/datasets/) collection.\n",
        "\n",
        "There are 188 graphs in this dataset, labeled with one of two classes, representing \"their mutagenic effect on a specific gram negative bacterium\". Node features represent the 1-hot encoding of the atom type (0=C, 1=N, 2=O, 3=F, 4=I, 5=Cl, 6=Br). Edge features  represent the bond type using a 1-hot encoding (0=aromatic,1=single, 2=double, 3=triple).\n",
        "\n",
        "Please note that this is an introductory example on homogeneous graphs (one node type, and one edge type). TF-GNN is designed to support heterogeneous graphs as well (multiple node types, and/or multiple edge types)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u55FfUXBs_0u"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sS8ot98DEgzJ"
      },
      "source": [
        "Before Python can `import tensorflow_gnn`, the PIP package [`tensorflow-gnn`](https://pypi.org/project/tensorflow-gnn/) needs to be downloaded and installed."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jShUzy75-L9y"
      },
      "outputs": [],
      "source": [
        "!pip install --pre -q tensorflow-gnn"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WlEpI7zrSHhb"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "import tensorflow_gnn as tfgnn\n",
        "from tensorflow_gnn.models import gcn\n",
        "from tensorflow_gnn.models import graph_sage\n",
        "from google.protobuf import text_format\n",
        "import tensorflow_gnn.proto.graph_schema_pb2 as schema_pb2\n",
        "\n",
        "print(f'Running TF-GNN {tfgnn.__version__} with TensorFlow {tf.__version__}.')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CmY2hw5Iqty2"
      },
      "source": [
        "### Download the MUTAG dataset\n",
        "We have created a version of the MUTAG Dataset in TF-GNN's file format to use as an example in this colab.\n",
        "\n",
        "Citation: [Morris, Christopher, et al. Tudataset: A collection of benchmark datasets for learning with graphs. arXiv preprint arXiv:2007.08663. 2020.](https://chrsmrrs.github.io/datasets/)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nLAzXPvLpQgw"
      },
      "outputs": [],
      "source": [
        "# Download and unzip dataset.\n",
        "!wget https://storage.googleapis.com/download.tensorflow.org/data/mutag.zip\n",
        "!unzip mutag.zip"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H4VykVm_qty3"
      },
      "outputs": [],
      "source": [
        "train_path = os.path.join(os.getcwd(), 'mutag', 'train.tfrecords')\n",
        "val_path = os.path.join(os.getcwd(), 'mutag', 'val.tfrecords')\n",
        "!ls -l {train_path} {val_path}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8V1hhu5571vA"
      },
      "source": [
        "### Generate GraphTensor\n",
        "\n",
        "Declare the GraphSchema protocol message which describes the entity and relation types, along with their respective features. Here's a depiction of the graph schema for the mutag dataset:\n",
        "\n",
        "![8RjLQuXw6nCjUgq.png]()\n",
        "\n",
        "\n",
        "Calling [tfgnn.create_graph_spec_from_schema_pb](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/tfgnn/create_graph_spec_from_schema_pb.md) with the GraphSchema will provide us the GraphTensorSpec. GraphTensorSpec can be used to parse serialized tf.Example protos into the GraphTensor using [tfgnn.parse_single_example](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/api_docs/python/tfgnn/parse_single_example.md) api with the tf.data.dataset's map function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "REBVKng0FOba"
      },
      "outputs": [],
      "source": [
        "# @title Declare GraphSchema { vertical-output: true }\n",
        "schema_pbtx = \"\"\"\n",
        "# proto-file: //third_party/py/tensorflow_gnn/proto/graph_schema.proto\n",
        "# proto-message: tensorflow_gnn.GraphSchema\n",
        "context {\n",
        "  features {\n",
        "    key: \"label\"\n",
        "    value: {\n",
        "      description: \"compound mutagenicity.\"\n",
        "      dtype: DT_INT32\n",
        "    }\n",
        "  }\n",
        "}\n",
        "node_sets {\n",
        "  key: \"atoms\"\n",
        "  value {\n",
        "    features {\n",
        "      key: \"hidden_state\"\n",
        "      value {\n",
        "        description: \"atom type.\"\n",
        "        dtype: DT_FLOAT\n",
        "        shape { dim { size: 7 } }\n",
        "      }\n",
        "    }\n",
        "  }\n",
        "}\n",
        "edge_sets {\n",
        "  key: \"bonds\"\n",
        "  value {\n",
        "    source: \"atoms\"\n",
        "    target: \"atoms\"\n",
        "    features {\n",
        "      key: \"hidden_state\"\n",
        "      value {\n",
        "        description: \"bond type.\"\n",
        "        dtype: DT_FLOAT\n",
        "        shape { dim { size: 4 } }\n",
        "      }\n",
        "    }\n",
        "  }\n",
        "}\n",
        "\"\"\"\n",
        "graph_schema = text_format.Merge(schema_pbtx, schema_pb2.GraphSchema())\n",
        "print(f\"graph_schema: {graph_schema}\")\n",
        "graph_tensor_spec = tfgnn.create_graph_spec_from_schema_pb(graph_schema)\n",
        "def decode_fn(record_bytes):  \n",
        "  graph = tfgnn.parse_single_example(\n",
        "      graph_tensor_spec, record_bytes, validate=True)\n",
        "\n",
        "  # extract label from context and remove from input graph\n",
        "  context_features = graph.context.get_features_dict()\n",
        "  label = context_features.pop('label')\n",
        "  new_graph = graph.replace_features(context=context_features)\n",
        "\n",
        "  return new_graph, label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9VMiHec0V4BP"
      },
      "outputs": [],
      "source": [
        "#@title Read Datasets\n",
        "train_ds = tf.data.TFRecordDataset([train_path]).map(decode_fn)\n",
        "val_ds = tf.data.TFRecordDataset([val_path]).map(decode_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-n3NZ-Tk-Woh"
      },
      "source": [
        "### Look at one example from the dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "J0BjC-y6-asJ"
      },
      "outputs": [],
      "source": [
        "g, y = train_ds.take(1).get_single_element()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ls-prW6QC27_"
      },
      "source": [
        "#### Node features\n",
        "\n",
        "Node features represent the 1-hot encoding of the atom type (0=C, 1=N, 2=O, 3=F,\n",
        "4=I, 5=Cl, 6=Br)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "39AId0POBf1v"
      },
      "outputs": [],
      "source": [
        "g.node_sets['atoms']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s8JabcxqC0ja"
      },
      "outputs": [],
      "source": [
        "g.node_sets['atoms'].features[tfgnn.HIDDEN_STATE]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2QCQkE8RDYOX"
      },
      "source": [
        "#### Bond Edges\n",
        "\n",
        "In this example, we consider the bonds between atoms as undirected edges. To encode them in the GraphTensor, we store the undirected edges as pairs of directed edges in both directions.\n",
        "\n",
        "`adjacency.source` contains the source node indices, and `adjacency.target` contains the corresponding target node indices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TxUPaUAz-U7x"
      },
      "outputs": [],
      "source": [
        "g.edge_sets['bonds'].adjacency.source"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hwJM9FoKDXEr"
      },
      "outputs": [],
      "source": [
        "g.edge_sets['bonds'].adjacency.target"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UI1X18IVC-5g"
      },
      "source": [
        "#### Edge features\n",
        "\n",
        "Edge features represent the bond type as one-hot encoding (0=aromatic,1=single,\n",
        "2=double, 3=triple)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a_p-RgX6DS3F"
      },
      "outputs": [],
      "source": [
        "g.edge_sets['bonds'].features[tfgnn.HIDDEN_STATE]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "29bdL2F2gY99"
      },
      "source": [
        "#### Context\n",
        "\n",
        "Context can be used to store graph-level information"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EhNzcyXVgoYb"
      },
      "outputs": [],
      "source": [
        "g.context"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gbeF1etYigfe"
      },
      "source": [
        "### Label\n",
        "The label is binary, indicating the mutagenicity of the molecule. It's either 0 or 1."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LgQGtPlwideV"
      },
      "outputs": [],
      "source": [
        "y"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3lhmmug-DvyX"
      },
      "source": [
        "#### Batch the datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jOUCfdpp_w4v"
      },
      "outputs": [],
      "source": [
        "batch_size = 32 #@param {type:\"integer\"}\n",
        "\n",
        "train_ds_batched = train_ds.batch(batch_size=batch_size).repeat()\n",
        "val_ds_batched = val_ds.batch(batch_size=batch_size)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "63Tqz_SPqyjY"
      },
      "source": [
        "### Build the GNN model\n",
        "\n",
        "TF-GNN provides Keras layers for building graph neural networks. The following code uses Keras' [Functional API](https://www.tensorflow.org/guide/keras/functional) to build a model as a series of GraphTensor transformations, followed by reading out a plain Tensor with the final prediction."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "E9qx2aeJZQMo"
      },
      "source": [
        "## Question!\n",
        "\n",
        "Modify the model code below so that it updates the context feature via the GraphUpdate layer using trainable weights to learn the context feature!\n",
        "\n",
        "After that update the readout layer accordingly!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fEuYYLuVJFVL"
      },
      "outputs": [],
      "source": [
        "# TIP: Check out modeling guide to learn more about GraphUpdate layer: \n",
        "# https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/gnn_modeling.md \n",
        "\n",
        "# This helper function is just a short-hand for the GraphUpdate layers.\n",
        "def dense(units, l2_regularization, dropout_rate, activation=\"relu\"):\n",
        "  \"\"\"A Dense layer with regularization (L2 and Dropout).\"\"\"\n",
        "  regularizer = tf.keras.regularizers.l2(l2_regularization)\n",
        "  return tf.keras.Sequential([\n",
        "      tf.keras.layers.Dense(\n",
        "        units,\n",
        "        activation=activation,\n",
        "        kernel_regularizer=regularizer,\n",
        "        bias_regularizer=regularizer),\n",
        "      tf.keras.layers.Dropout(dropout_rate)\n",
        "    ])\n",
        "\n",
        "def get_graph_update_layer(message_dim, \n",
        "                           next_state_dim, \n",
        "                           l2_regularization=5e-4, \n",
        "                           dropout_rate=0.5):\n",
        "  \"\"\"Function that returns the GraphUpdate layer.\"\"\"\n",
        "  return tfgnn.keras.layers.GraphUpdate(\n",
        "      # TIP: Add a context state to get updated with each GraphUpdate pass.\n",
        "      node_sets={\n",
        "          \"atoms\": tfgnn.keras.layers.NodeSetUpdate(\n",
        "              {\"bonds\": tfgnn.keras.layers.SimpleConv(\n",
        "                  sender_edge_feature=tfgnn.HIDDEN_STATE,\n",
        "                  message_fn=dense(message_dim,\n",
        "                                  l2_regularization,\n",
        "                                  dropout_rate),\n",
        "                  reduce_type=\"sum\",\n",
        "                  receiver_tag=tfgnn.TARGET)},\n",
        "              tfgnn.keras.layers.NextStateFromConcat(dense(\n",
        "                  next_state_dim,\n",
        "                  l2_regularization,\n",
        "                  dropout_rate)))},\n",
        "        )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b7fhATRUexKh"
      },
      "outputs": [],
      "source": [
        "def _build_model(\n",
        "    graph_tensor_spec,\n",
        "    # Dimensions of initial states.\n",
        "    node_dim=16,\n",
        "    edge_dim=16,\n",
        "    # Dimensions for message passing.\n",
        "    message_dim=64,\n",
        "    next_state_dim=64,\n",
        "    # Dimension for the logits.\n",
        "    num_classes=2,\n",
        "    # Number of message passing steps.\n",
        "    num_message_passing=3,\n",
        "    # Other hyperparameters.\n",
        "    l2_regularization=5e-4,\n",
        "    dropout_rate=0.5,\n",
        "):\n",
        "  # Model building with Keras's Functional API starts with an input object\n",
        "  # (a placeholder for the eventual inputs). Here is how it works for\n",
        "  # GraphTensors:\n",
        "  input_graph = tf.keras.layers.Input(type_spec=graph_tensor_spec)\n",
        "\n",
        "  # IMPORTANT: All TF-GNN modeling code assumes a GraphTensor of shape []\n",
        "  # in which the graphs of the input batch have been merged to components of\n",
        "  # one contiguously indexed graph. (There are no edges between components,\n",
        "  # so no information flows between them.)\n",
        "  graph = input_graph.merge_batch_to_components()\n",
        "\n",
        "  # Nodes and edges have one-hot encoded input features. Sending them through\n",
        "  # a Dense layer effectively does a lookup in a trainable embedding table.\n",
        "  def set_initial_node_state(node_set, *, node_set_name):\n",
        "    # Since we only have one node set, we can ignore node_set_name.\n",
        "    return tf.keras.layers.Dense(node_dim)(node_set[tfgnn.HIDDEN_STATE])\n",
        "  def set_initial_edge_state(edge_set, *, edge_set_name):\n",
        "    return tf.keras.layers.Dense(edge_dim)(edge_set[tfgnn.HIDDEN_STATE])\n",
        "\n",
        "  # MapFeatures layer receives callbacks as input for each graph piece: node_set\n",
        "  # edge_set and context. Each callback applies a transformation over the \n",
        "  # existing features of the respective graph piece while using a Keras\n",
        "  # Functional API to call new Keras Layers. For more information and examples \n",
        "  # about the MapFeatures layer please check out its docstring. This call here \n",
        "  # initializes the hidden states of the edge and node sets. \n",
        "  graph = tfgnn.keras.layers.MapFeatures(\n",
        "      node_sets_fn=set_initial_node_state, edge_sets_fn=set_initial_edge_state)(\n",
        "          graph)\n",
        "      \n",
        "  num_atoms = tf.expand_dims(tf.cast(graph.node_sets[\"atoms\"].sizes, dtype=tf.float32), axis=-1)\n",
        "  num_bonds = tf.expand_dims(tf.cast(graph.edge_sets[\"bonds\"].sizes, dtype=tf.float32), axis=-1)\n",
        "  graph = graph.replace_features(\n",
        "      context={\n",
        "          tfgnn.HIDDEN_STATE: tf.concat([num_atoms, num_bonds], axis=1)\n",
        "          })\n",
        " \n",
        "  # This helper function is just a short-hand for the code below.\n",
        "  def dense(units, activation=\"relu\"):\n",
        "    \"\"\"A Dense layer with regularization (L2 and Dropout).\"\"\"\n",
        "    regularizer = tf.keras.regularizers.l2(l2_regularization)\n",
        "    return tf.keras.Sequential([\n",
        "        tf.keras.layers.Dense(\n",
        "            units,\n",
        "            activation=activation,\n",
        "            kernel_regularizer=regularizer,\n",
        "            bias_regularizer=regularizer),\n",
        "        tf.keras.layers.Dropout(dropout_rate)\n",
        "    ])\n",
        "\n",
        "  # The GNN core of the model does `num_message_passing` many updates of node\n",
        "  # states conditioned on their neighbors and the edges connecting to them.\n",
        "  # More precisely:\n",
        "  #  - Each edge computes a message by applying a dense layer `message_fn`\n",
        "  #    to the concatenation of node states of both endpoints (by default)\n",
        "  #    and the edge's own unchanging feature embedding.\n",
        "  #  - Messages are summed up at the common TARGET nodes of edges.\n",
        "  #  - At each node, a dense layer is applied to the concatenation of the old\n",
        "  #    node state with the summed edge inputs to compute the new node state.\n",
        "  # Each iteration of the for-loop creates new Keras Layer objects, so each\n",
        "  # round of updates gets its own trainable variables.\n",
        "  for i in range(num_message_passing):   \n",
        "    graph = tfgnn.keras.layers.GraphUpdate(\n",
        "        node_sets={\n",
        "            \"atoms\": tfgnn.keras.layers.NodeSetUpdate(\n",
        "                {\"bonds\": tfgnn.keras.layers.SimpleConv(\n",
        "                    sender_edge_feature=tfgnn.HIDDEN_STATE,\n",
        "                    message_fn=dense(message_dim),\n",
        "                    reduce_type=\"sum\",\n",
        "                    receiver_tag=tfgnn.TARGET)},\n",
        "                tfgnn.keras.layers.NextStateFromConcat(dense(next_state_dim)))},\n",
        "        )(graph)\n",
        "    \n",
        "\n",
        "  # After the GNN has computed a context-aware representation of the \"atoms\",\n",
        "  # the model reads out a representation for the graph as a whole by averaging\n",
        "  # (pooling) node states into the graph context. The context is global to each\n",
        "  # input graph of the batch, so the first dimension of the result corresponds\n",
        "  # to the batch dimension of the inputs (same as the labels).\n",
        "  readout_features = tfgnn.keras.layers.Pool(\n",
        "      tfgnn.CONTEXT, \"mean\", node_set_name=\"atoms\")(graph)\n",
        "  # Context  has a hidden-state feature, concatenate the aggregated node vectors\n",
        "  # with the hidden-state to get the final vector,\n",
        "  feat = tf.concat([readout_features, graph.context[tfgnn.HIDDEN_STATE]], axis=1)\n",
        "  # Put a linear classifier on top (not followed by dropout).\n",
        "  logits = tf.keras.layers.Dense(1)(feat)\n",
        "\n",
        "  # Build a Keras Model for the transformation from input_graph to logits.\n",
        "  return tf.keras.Model(inputs=[input_graph], outputs=[logits])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sEGSG5ZXaRJ_"
      },
      "source": [
        "Feel free to play with the hyperparameters and the model architecture to improve the results!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PjM39AM2E0TD"
      },
      "source": [
        "#### Define Loss and Metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6IAO1FToa5xk"
      },
      "outputs": [],
      "source": [
        "node_dim = 16 #@param {type:\"integer\"}\n",
        "edge_dim = 16 #@param {type:\"integer\"}\n",
        "message_dim = 64 #@param {type:\"integer\"}\n",
        "next_state_dim = 64 #@param {type:\"integer\"}\n",
        "num_classes = 2 #@param {type:\"integer\"}\n",
        "num_message_passing = 3 #@param {type:\"integer\"}\n",
        "l2_regularization = 5e-4 #@param {type:\"number\"}\n",
        "dropout_rate = 0.5 #@param {type:\"number\"}\n",
        "\n",
        "model_input_graph_spec, label_spec = train_ds.element_spec\n",
        "del label_spec # Unused.\n",
        "model = _build_model(model_input_graph_spec,\n",
        "                     node_dim=node_dim,\n",
        "                     edge_dim=edge_dim,\n",
        "                     message_dim=message_dim,\n",
        "                     next_state_dim=next_state_dim,\n",
        "                     num_classes=num_classes,\n",
        "                     num_message_passing=num_message_passing,\n",
        "                     l2_regularization=l2_regularization,\n",
        "                     dropout_rate=dropout_rate)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Tyxzv1OQIX81"
      },
      "outputs": [],
      "source": [
        "train_ds.element_spec"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VcrTzMEfexIm"
      },
      "outputs": [],
      "source": [
        "loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)\n",
        "metrics = [tf.keras.metrics.BinaryAccuracy(threshold=0.),\n",
        "            tf.keras.metrics.BinaryCrossentropy(from_logits=True)]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cREyDPxrFBH-"
      },
      "source": [
        "#### Compile the keras model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1niHXmREKfw9"
      },
      "outputs": [],
      "source": [
        "# @title Optimizer \n",
        "learning_rate = 1e-3 #@param {type:\"number\"}\n",
        "learning_rate_decay = True #@param {type:\"boolean\"}\n",
        "steps_per_epoch = 10 #@param {type:\"integer\"}\n",
        "epochs = 200 #@param {type:\"integer\"}\n",
        "\n",
        "if learning_rate_decay:\n",
        "  learning_rate = tf.keras.optimizers.schedules.CosineDecay(learning_rate, steps_per_epoch*epochs)\n",
        "\n",
        "optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)\n",
        "\n",
        "model.compile(optimizer, loss=loss, metrics=metrics)\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wB5JoJ9pFDGx"
      },
      "source": [
        "#### Train the model"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uMotwUlM8U6s"
      },
      "outputs": [],
      "source": [
        "steps_per_epoch = 10 #@param {type:\"integer\"}\n",
        "epochs = 200 #@param {type:\"integer\"}\n",
        "history = model.fit(train_ds_batched,\n",
        "                    steps_per_epoch=steps_per_epoch,\n",
        "                    epochs=epochs,\n",
        "                    validation_data=val_ds_batched)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "coXPPpmUFKAh"
      },
      "source": [
        "### Plot the loss and metric curves for train and val"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0Qd2iuVZpsRH"
      },
      "outputs": [],
      "source": [
        "for k, hist in history.history.items():\n",
        "  plt.plot(hist)\n",
        "  plt.title(k)\n",
        "  plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MhHcIB1kohCC"
      },
      "source": [
        "## GraphTensor Advanced\n",
        "\n",
        "TF-GNN layers work with scalar graph-tensors, merge_batch_to_components() api will convert a graph-tensor with a batch of graphs into a scalar graph where each graph in the batch becomes an unconnected component of the scalar graph."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Rp2tkW9co2OD"
      },
      "source": [
        "### Batch of GraphTensors:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3D7450Mrovyb"
      },
      "outputs": [],
      "source": [
        "batch_size = 5 #@param {type:\"integer\"}\n",
        "\n",
        "gt, y = train_ds.batch(batch_size).take(1).get_single_element()\n",
        "gt"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AoSLfGexqBMa"
      },
      "outputs": [],
      "source": [
        "gt.node_sets[\"atoms\"].features[tfgnn.HIDDEN_STATE]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P7kq6N2siYFb"
      },
      "outputs": [],
      "source": [
        "gt.edge_sets[\"bonds\"].adjacency.source"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yF28VWvTp4Lr"
      },
      "source": [
        "### Merged Scalar GraphTensor"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QaHBOB_yprTx"
      },
      "outputs": [],
      "source": [
        "scalar_gt = gt.merge_batch_to_components()\n",
        "scalar_gt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GH1bNxa_iGlA"
      },
      "source": [
        "Let's look at the node-sets and edge-sets of the scalar graph:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yVVz8B7zqD2f"
      },
      "outputs": [],
      "source": [
        "scalar_gt.node_sets[\"atoms\"].features[tfgnn.HIDDEN_STATE]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jo5MVlt_ikzU"
      },
      "outputs": [],
      "source": [
        "scalar_gt.edge_sets[\"bonds\"].adjacency.source"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2thndLDNfyRt"
      },
      "source": [
        "## Heterogeneos Graph Convolutions\n",
        "\n",
        "In the following subsections you can see a code snippet to transform the graph with a GraphUpdate call over the **M**icrosoft **A**cademic **G**raph ([ogb-mag](https://ogb.stanford.edu/)) dataset with the depicted schema below.  Checkout this [colab](https://colab.sandbox.google.com/github/tensorflow/gnn/blob/master/examples/notebooks/ogbn_mag_e2e.ipynb#scrollTo=b7fhATRUexKh) for the full example.\n",
        "\n",
        " "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8gCLCZO4eeKL"
      },
      "source": [
        "![6V9XQ4k7JF5FvQK.png]()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JZeu4FzCITwP"
      },
      "source": [
        "## GraphUpdate over OGB-MAG dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "k3ZF1rggIV69"
      },
      "source": [
        "![image.png]()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "ScitaPqhKtuW"
      ],
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1NAHzoiKVDqpUnrHYm4EmugesVBSOHc9Y",
          "timestamp": 1669589521986
        },
        {
          "file_id": "https://github.com/tensorflow/gnn/blob/master/examples/tutorials/neurips_2022/neurips_student_tfgnn_graph_classification_mutag.ipynb",
          "timestamp": 1669589412082
        },
        {
          "file_id": "1HIsmwmUYFiNx578dNIQCeq0xPbh-nRo8",
          "timestamp": 1669173297036
        },
        {
          "file_id": "1-zvRMMHctfqx6X7LHHPOOM2Ybv-Dy9A8",
          "timestamp": 1669171796245
        },
        {
          "file_id": "https://github.com/tensorflow/gnn/blob/master/examples/notebooks/intro_mutag_example.ipynb",
          "timestamp": 1668716349772
        },
        {
          "file_id": "199Qj2Kc73yreV9szw3ZvghJb9aG8R8Us",
          "timestamp": 1639468859837
        },
        {
          "file_id": "1Qm-6zA_twOkAsgCoA5Zl2X9RaFa04Zii",
          "timestamp": 1637635073567
        }
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
